import os
import logging
from dgl.dataloading import MultiLayerFullNeighborSampler
from dgl.dataloading import DataLoader
from torch.optim.lr_scheduler import MultiStepLR
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
import dgl
import pickle
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import average_precision_score, roc_auc_score, f1_score, accuracy_score
from scipy.io import loadmat
from tqdm import tqdm
from . import *
from .rgtan_lpa import load_lpa_subtensor
from .rgtan_model import RGTAN
import torch.nn.functional as F
# Module-level default device: CUDA when available, otherwise CPU.
# NOTE(review): rgtan_main assigns its own local `device` from args['device'],
# so this global is only a fallback for helpers that do not receive one.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def initialize_centroids(features, k):
    """Pick ``k`` initial cluster centres from ``features`` using k-means++ seeding.

    The first centre is a uniformly random row. Every further centre is drawn
    with probability proportional to the *squared* distance to the nearest
    centre chosen so far (the D^2 weighting that defines k-means++).

    Args:
        features: 2-D tensor ``[num_nodes, feature_dim]``.
        k: number of centres to pick.

    Returns:
        Tensor ``[k, feature_dim]`` with the same dtype and device as ``features``.
    """
    num_nodes = features.size(0)
    # Match the input dtype so that torch.cdist below never sees mixed dtypes.
    centroids = torch.zeros(k, features.size(1), dtype=features.dtype, device=features.device)

    # The first centre is chosen uniformly at random.
    first_id = torch.randint(num_nodes, (1,)).item()
    centroids[0] = features[first_id]

    for i in range(1, k):
        # Distance from each point to its nearest already-chosen centre.
        nearest = torch.cdist(features, centroids[:i]).min(dim=1).values
        weights = nearest.pow(2)
        total = weights.sum()
        if total > 0:
            probabilities = weights / total
        else:
            # Every point coincides with a chosen centre. Sample uniformly rather
            # than dividing by zero, which would hand NaNs to torch.multinomial.
            probabilities = torch.ones_like(weights)
        next_id = torch.multinomial(probabilities, 1).item()
        centroids[i] = features[next_id]

    return centroids

def check_convergence(centroids, prev_centroids, tol=1e-4):
    """Report whether the cluster centres have effectively stopped moving.

    The movement is measured as the Frobenius (L2) norm of
    ``centroids - prev_centroids``.

    Returns:
        A 0-d boolean tensor that is True when that movement is below ``tol``.
    """
    shift = (centroids - prev_centroids).norm()
    return shift < tol
def robust_node_clustering(features, k=2, temperature=0.1, max_iterations=10, labeled_features=None, labeled_classes=None):
    """Soft-cluster node features, optionally seeding centres from labelled nodes.

    Two ways of finding the centres:

    * Both ``labeled_features`` and ``labeled_classes`` given: centre ``i`` is the
      mean feature of the labelled nodes of class ``i``. If a class has no
      labelled node, its centre is a random unit vector. All centres are then
      L2-normalised, and no refinement iterations run.
    * Otherwise: centres are seeded with ``initialize_centroids`` and refined for
      up to ``max_iterations`` rounds of soft (Gumbel-softmax weighted) k-means.
      Refinement stops early once the centres move less than ``1e-4``.

    The centre search runs under ``torch.no_grad()``. The final assignment is
    recomputed outside it, so gradients can flow back into ``features``.

    Args:
        features: node features ``[num_nodes, feature_dim]``.
        k: number of clusters (default 2, i.e. binary classification).
        temperature: softness of the assignment. It scales the logits and is
            also used as the Gumbel-softmax ``tau``.
        max_iterations: maximum refinement rounds (unlabelled path only).
        labeled_features: features of labelled nodes ``[num_labeled, feature_dim]``.
        labeled_classes: class ids of labelled nodes ``[num_labeled]``, expected
            to lie in ``[0, k)``.

    Returns:
        tuple ``(original, view1, view2, centroids)``:

        * ``original`` is the soft assignment ``[num_nodes, k]``.
        * ``centroids`` is ``[k, feature_dim]``.
        * ``view1`` and ``view2`` are the *same tensor object* as ``original``.
          No separate augmented views are clustered here; the two slots are
          kept for interface compatibility.
    """
    feature_dim = features.size(1)
    device = features.device
    dtype = features.dtype
    use_labels = labeled_features is not None and labeled_classes is not None

    # The centre search itself needs no gradients.
    with torch.no_grad():
        if use_labels:
            centroids = torch.zeros(k, feature_dim, device=device, dtype=dtype)
            for cls in range(k):
                members = torch.where(labeled_classes == cls)[0]
                if len(members) > 0:
                    # Class prototype = mean of its labelled examples.
                    centroids[cls] = labeled_features[members].mean(dim=0)
                else:
                    # No labelled example for this class: use a random unit vector.
                    centroids[cls] = F.normalize(
                        torch.randn(feature_dim, device=device, dtype=dtype), p=2, dim=0)
            # Give every centre unit norm (the epsilon guards an all-zero mean).
            norms = torch.norm(centroids, dim=1, keepdim=True)
            centroids = centroids / (norms + 1e-10)
        else:
            centroids = initialize_centroids(features, k)
            prev_centroids = centroids.clone()
            for _ in range(max_iterations):
                logits = -torch.cdist(features, centroids) / temperature
                assignments = F.gumbel_softmax(logits, tau=temperature, hard=False)

                # Each new centre is the assignment-weighted mean of the features.
                new_centroids = torch.zeros_like(centroids)
                for j in range(k):
                    weights = assignments[:, j].unsqueeze(1)  # [num_nodes, 1]
                    weight_sum = weights.sum()
                    if weight_sum > 0:
                        new_centroids[j] = (features * weights).sum(0) / weight_sum
                    else:
                        new_centroids[j] = centroids[j]  # keep the previous centre
                centroids = new_centroids

                if check_convergence(centroids, prev_centroids, tol=1e-4):
                    break
                prev_centroids = centroids.clone()

    # Final assignment is computed with gradients enabled.
    logits_original = -torch.cdist(features, centroids) / temperature
    original_cluster_assignments = F.gumbel_softmax(logits_original, tau=temperature, hard=False)

    view1_cluster_assignments = original_cluster_assignments
    view2_cluster_assignments = original_cluster_assignments

    return original_cluster_assignments, view1_cluster_assignments, view2_cluster_assignments, centroids
def compute_clustering_loss(features, cluster_assignments, centroids, epsilon=1e-6):
    """Clustering objective for the binary (k=2) case on L2-normalised features.

    Nodes are hard-assigned with ``argmax`` over ``cluster_assignments``. Class 1
    is treated as positive and class 0 as negative. Three terms are combined:

    * ``intra``: mean distance of positives to centroid 1, plus mean distance of
      negatives to centroid 0. An empty side contributes 0.
    * ``inter``: negative mean pairwise centroid distance, which pushes the
      centres apart.
    * ``joint_reg``: mean squared distance to the assigned centre, divided by the
      mean centroid separation.

    The result is ``0.5 * intra + 0.5 * inter + 0.1 * joint_reg``.

    Args:
        features: ``[N, D]`` node embeddings (normalised internally).
        cluster_assignments: ``[N, K]`` assignment scores (soft or hard).
        centroids: ``[K, D]`` cluster centres (normalised internally).
        epsilon: guard added to the centroid-separation denominator.

    Returns:
        ``(total_loss, num_pos, num_neg)``: a scalar tensor and two Python ints.
    """
    features = F.normalize(features, p=2, dim=1)
    centroids = F.normalize(centroids, p=2, dim=1)

    with torch.no_grad():
        hard_assignments = torch.argmax(cluster_assignments, dim=1)
        pos_indices = torch.nonzero(hard_assignments == 1).squeeze(-1)
        neg_indices = torch.nonzero(hard_assignments == 0).squeeze(-1)
        num_pos = pos_indices.numel()
        num_neg = neg_indices.numel()

    distances = torch.cdist(features, centroids)  # [N, K]
    zero = torch.tensor(0.0, device=features.device)
    intra_positive_loss = distances[pos_indices, 1].mean() if num_pos > 0 else zero
    intra_negative_loss = distances[neg_indices, 0].mean() if num_neg > 0 else zero
    intra_loss = intra_positive_loss + intra_negative_loss

    # The mean centroid separation feeds both the inter term and the regulariser.
    mean_centroid_dist = torch.pdist(centroids).mean()
    inter_loss = -mean_centroid_dist

    assigned_centroids = torch.index_select(centroids, 0, hard_assignments)
    compactness = torch.mean(torch.sum((features - assigned_centroids) ** 2, dim=1))
    joint_reg = compactness / (mean_centroid_dist + epsilon)

    total_loss = 0.5 * intra_loss + 0.5 * inter_loss + 0.1 * joint_reg
    return total_loss, num_pos, num_neg

def nt_xent_loss(z_i, z_j, temperature=0.01):
    """NT-Xent (normalised temperature-scaled cross entropy) contrastive loss.

    Row ``a`` of ``z_i`` and row ``a`` of ``z_j`` form the positive pair. Every
    other row of either view is a negative, and self-similarity is excluded.

    :param z_i: Tensor ``[B, D]``, representations of the first view.
    :param z_j: Tensor ``[B, D]``, representations of the second view.
    :param temperature: Float, temperature scaling factor for the similarities.
    :return: scalar tensor, the mean loss over the ``2B`` anchors.
    """
    z_i = F.normalize(z_i, dim=-1)
    z_j = F.normalize(z_j, dim=-1)
    representations = torch.cat([z_i, z_j], dim=0)
    n = representations.size(0)
    dev = representations.device  # build masks where the data lives

    sim_matrix = F.cosine_similarity(
        representations.unsqueeze(1), representations.unsqueeze(0), dim=-1)

    # Anchor r's positive is the other view of the same sample.
    labels = torch.arange(z_i.size(0), device=dev).repeat(2)
    positives = labels[:, None] == labels[None, :]

    # Drop the diagonal (self-similarity) from both logits and positive mask.
    not_self = ~torch.eye(n, dtype=torch.bool, device=dev)
    logits = (sim_matrix / temperature)[not_self].view(n, -1)
    positives = positives[not_self].view(n, -1)

    # The loss is log(sum_all exp) - log(sum_pos exp). It is evaluated with
    # logsumexp because small temperatures overflow a plain exp(): the default
    # 0.01 gives logits up to 100, and e^100 is inf in float32.
    log_pos = torch.logsumexp(logits.masked_fill(~positives, float('-inf')), dim=-1)
    log_all = torch.logsumexp(logits, dim=-1)
    return (log_all - log_pos).mean()

def generate_contrastive_pairs(batch_nodes, labels):
    """Sample one positive and one negative partner for every node in a batch.

    For batch position ``i``:

    * The positive partner is drawn uniformly (``np.random.choice``) among batch
      positions ``j`` with the same label and a different node id.
    * The negative partner is drawn among positions with a different label.

    A position with no candidate of a kind gets no pair of that kind. Pairs hold
    *positions within the batch*, not global node ids.

    :param batch_nodes: node ids of the current batch (sequence, ndarray or tensor).
    :param labels: node labels, indexable by node id (ndarray or tensor).
    :return: ``(positive_pairs, negative_pairs)``, two lists of ``(i, j)`` tuples.
    """
    # Convert torch tensors on any device to NumPy, so that the comparisons and
    # dict lookups below work on plain scalars.
    if isinstance(labels, torch.Tensor):
        labels = labels.detach().cpu().numpy()
    if isinstance(batch_nodes, torch.Tensor):
        batch_nodes = batch_nodes.detach().cpu().numpy()

    batch_nodes_list = list(batch_nodes)
    batch_labels = list(labels[batch_nodes])

    # Bucket positions by label once, in ascending order. The candidate lists,
    # and therefore the np.random.choice draws, match a full linear scan.
    positions_by_label = {}
    for pos, lab in enumerate(batch_labels):
        positions_by_label.setdefault(lab, []).append(pos)
    negatives_by_label = {
        lab: [pos for pos, other in enumerate(batch_labels) if other != lab]
        for lab in positions_by_label
    }

    positive_pairs = []
    negative_pairs = []
    for i, (node, lab) in enumerate(zip(batch_nodes_list, batch_labels)):
        # Same label, excluding every occurrence of this very node id.
        same_class = [j for j in positions_by_label[lab] if batch_nodes_list[j] != node]
        if same_class:
            positive_pairs.append((i, np.random.choice(same_class)))

        diff_class = negatives_by_label[lab]
        if diff_class:
            negative_pairs.append((i, np.random.choice(diff_class)))

    return positive_pairs, negative_pairs

def get_augmented_view(edge_indexs, feat_data, aug_type, drop_rate=0.2):
    """Build one augmented view of a (HOGRL-style) multi-relation graph.

    Args:
        edge_indexs: per-relation list whose items are indexed as
            ``[main_edge_index, [tree_edge_index, ...]]``. It may be ``None`` for
            ``'feat_drop'``, which never reads it.
        feat_data: node feature matrix ``[num_nodes, feature_dim]``.
        aug_type: ``'feat_drop'`` or ``'weighted_feat'`` (feature augmentations),
            or ``'edge_drop'``, ``'degree'`` or ``'pr'`` (edge augmentations).
        drop_rate: fraction of feature columns or edges to drop.

    Returns:
        Feature augmentations return ``(edge_indexs, augmented_features)``.
        Edge augmentations return ``(feat_data, augmented_edge_indexs)``.
        NOTE: the two families return their elements in opposite order. The order
        is kept as-is for backward compatibility with existing callers.

    Raises:
        ValueError: if ``aug_type`` is not one of the supported values. This is
            checked before any work, so it also applies when ``edge_indexs`` is
            ``None`` or empty.
    """
    feature_augs = ('feat_drop', 'weighted_feat')
    edge_augs = ('edge_drop', 'degree', 'pr')
    if aug_type not in feature_augs + edge_augs:
        raise ValueError(f"不支持的增强类型: {aug_type}")

    if aug_type == 'feat_drop':
        # Zero out whole feature columns. The mask is drawn on the features'
        # device to avoid a host-to-device transfer.
        keep = torch.rand(feat_data.size(1), device=feat_data.device) > drop_rate
        feat_aug = feat_data.clone()
        feat_aug[:, ~keep] = 0
        return edge_indexs, feat_aug

    if aug_type == 'weighted_feat':
        # Node degree from the main edge index of the first relation.
        node_deg = degree(edge_indexs[0][0][1])
        feat_weights = feature_drop_weights(feat_data, node_deg)
        feat_aug = drop_feature_weighted(feat_data, feat_weights, drop_rate)
        return edge_indexs, feat_aug

    def random_drop(edges):
        # Keep each edge (column) independently with probability 1 - drop_rate.
        keep = torch.rand(edges.size(1), device=edges.device) > drop_rate
        return edges[:, keep]

    augmented_edge_indexs = []
    for edge_index in edge_indexs:
        main_edges, tree_edges = edge_index[0], edge_index[1]
        if aug_type == 'edge_drop':
            edge_index_main = random_drop(main_edges)
            edge_index_trees = [random_drop(tree_edge) for tree_edge in tree_edges]
        elif aug_type == 'degree':
            # Degree-weighted edge dropping.
            edge_index_main = drop_edge_weighted(
                main_edges, degree_drop_weights(main_edges), p=drop_rate)
            edge_index_trees = [
                drop_edge_weighted(tree_edge, degree_drop_weights(tree_edge), p=drop_rate)
                for tree_edge in tree_edges
            ]
        else:  # 'pr': PageRank-weighted edge dropping
            edge_index_main = drop_edge_weighted(
                main_edges, pr_drop_weights(main_edges, aggr='sink', k=10), p=drop_rate)
            edge_index_trees = [
                drop_edge_weighted(tree_edge, pr_drop_weights(tree_edge, aggr='sink', k=10), p=drop_rate)
                for tree_edge in tree_edges
            ]
        augmented_edge_indexs.append([edge_index_main, edge_index_trees])

    return feat_data, augmented_edge_indexs



# Module flag read by get_current_mu: when True, the mu coefficient is ramped up.
mu_rampup = True
# NOTE(review): get_current_mu does not read this global (it uses its own local
# ramp-up length). Confirm whether anything outside this file still uses it.
consistency_rampup = None
def sigmoid_rampup(current, rampup_length):
    '''Sigmoid-shaped ramp from exp(-5) up to 1 (https://arxiv.org/abs/1610.02242).

    Returns ``exp(-5 * (1 - t) ** 2)`` as a Python float, where
    ``t = clip(current, 0, rampup_length) / rampup_length``. A zero-length
    ramp is treated as already finished and yields 1.0.
    '''
    if rampup_length != 0:
        remaining = 1.0 - np.clip(current, 0.0, rampup_length) / rampup_length
        return float(np.exp(-5.0 * remaining * remaining))
    return 1.0
def get_current_mu(epoch, mu=1.5, rampup_length=500):
    """Return the loss weight ``mu`` for ``epoch``, optionally ramped up.

    When the module flag ``mu_rampup`` is True, ``mu`` is scaled by
    ``sigmoid_rampup(epoch, rampup_length)`` (consistency ramp-up from
    https://arxiv.org/abs/1610.02242). Otherwise ``mu`` is returned unchanged.

    The ramp length is a parameter rather than a local that shadowed the
    unrelated module global ``consistency_rampup``. The defaults reproduce the
    previous hard-coded values.

    Args:
        epoch: current epoch; the ramp clips it to ``[0, rampup_length]``.
        mu: final weight reached at the end of the ramp (default 1.5).
        rampup_length: number of epochs the ramp takes (default 500).
    """
    if not mu_rampup:
        return mu
    return mu * sigmoid_rampup(epoch, rampup_length)

# 计算几何平均值Gmean
def geometric_mean(recall_0, recall_1):
    """Return ``sqrt(recall_0 * recall_1)``, computed with NumPy."""
    product = recall_0 * recall_1
    return np.sqrt(product)

# 计算G-mean
def calculate_g_mean(y_true, y_pred):
    """Return the G-mean of binary predictions: sqrt(recall_1 * recall_0).

    Both arguments may be any array-like (list, NumPy array, pandas Series, CPU
    tensor); they are converted with ``np.asarray`` so boolean masking works.
    A class absent from ``y_true`` contributes a recall of 0, so the result is 0.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    pos_indices = y_true == 1
    neg_indices = y_true == 0

    recall_pos = np.mean(y_pred[pos_indices] == y_true[pos_indices]) if np.any(pos_indices) else 0
    recall_neg = np.mean(y_pred[neg_indices] == y_true[neg_indices]) if np.any(neg_indices) else 0

    return geometric_mean(recall_neg, recall_pos)

# 添加 GradientAwareFocalLoss 类
class GradientAwareFocalLoss(nn.Module):
    """Focal loss re-weighted by per-sample gradient magnitude and class rarity.

    Each sample's cross entropy ``-log p_t`` is multiplied by three factors:

    * a focal term ``(1 - p_t) ** gamma_focal``;
    * a gradient term ``||d CE / d inputs||_2 ** gamma_grad``. It is computed on a
      detached copy of the inputs, so it acts purely as a weight;
    * a class weight ``(1 / count) ** (1 - gamma_ga)``. Here ``count`` is an
      exponential moving average (momentum 0.9) of how often each class appears
      among the ``k_percent`` % highest-loss samples.

    The weights are renormalised to mean 1 before averaging.

    Args:
        num_classes: number of classes ``C``.
        k_percent: percentage of highest-loss samples used to update class counts.
        gamma_focal: focal exponent.
        gamma_ga: exponent control for the class weights (see above).
        gamma_grad: exponent applied to the gradient magnitude.
        use_softmax: if True, ``inputs`` are logits; otherwise they are used
            directly as probabilities.
    """

    def __init__(self, num_classes, k_percent=10, gamma_focal=2.0, gamma_ga=0.5, gamma_grad=1.0, use_softmax=True):
        super(GradientAwareFocalLoss, self).__init__()
        self.num_classes = num_classes
        self.k_percent = k_percent
        self.gamma_focal = gamma_focal
        self.gamma_ga = gamma_ga
        self.gamma_grad = gamma_grad
        self.use_softmax = use_softmax
        # Running state updated on every forward call.
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_weights', torch.ones(num_classes))

    def forward(self, inputs, targets):
        """Return the scalar loss for ``inputs`` ``[B, C, *]`` and ``targets`` ``[B, *]``.

        Side effect: updates the ``class_counts`` and ``class_weights`` buffers.
        """
        B, C = inputs.shape[:2]
        N = inputs.shape[2:].numel() * B
        # Permutation that moves the class axis last. After flattening, row r of
        # the [N, C] matrix then lines up with targets.reshape(-1)[r].
        class_last = (0, *range(2, inputs.dim()), 1)

        probs = F.softmax(inputs, dim=1) if self.use_softmax else inputs
        probs = probs.permute(*class_last).contiguous().view(-1, C)
        targets = targets.reshape(-1)
        pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        ce_loss = -torch.log(pt + 1e-8)

        # Per-sample gradient magnitude, measured on a detached copy.
        # NOTE(review): F.cross_entropy applies its own log-softmax. With
        # use_softmax=True this therefore differentiates CE of the probabilities
        # (softmax applied twice). Kept as-is to preserve the weighting.
        inputs_grad = inputs.detach().requires_grad_(True)
        probs_grad = F.softmax(inputs_grad, dim=1) if self.use_softmax else inputs_grad
        # Use the same class-last flattening as above. A bare .view(-1, C) would
        # mix classes and positions whenever the inputs have spatial dims.
        probs_grad = probs_grad.permute(*class_last).contiguous().view(-1, C)
        loss_grad = F.cross_entropy(probs_grad, targets, reduction='none')
        gradients = torch.autograd.grad(
            outputs=loss_grad,
            inputs=inputs_grad,
            grad_outputs=torch.ones_like(loss_grad),
            create_graph=False,
        )[0]

        gradients = gradients.permute(*class_last).contiguous().view(-1, C)
        grad_magnitude = gradients.norm(p=2, dim=1)
        grad_weight = (grad_magnitude + 1e-8) ** self.gamma_grad

        # Class frequency among the hardest k_percent of samples (EMA).
        num_topk = max(1, int(self.k_percent / 100 * N))
        _, topk_indices = torch.topk(ce_loss, num_topk, sorted=False)
        topk_targets = targets[topk_indices]
        current_counts = torch.bincount(topk_targets, minlength=self.num_classes).float()
        self.class_counts = 0.9 * self.class_counts + 0.1 * current_counts
        effective_counts = self.class_counts + 1e-8
        self.class_weights = (1.0 / effective_counts) ** (1.0 - self.gamma_ga)
        self.class_weights = self.class_weights / self.class_weights.sum() * C

        focal_weight = (1 - pt) ** self.gamma_focal
        class_weight = self.class_weights[targets]

        difficulty_weight = class_weight * grad_weight
        difficulty_weight = difficulty_weight / (difficulty_weight.mean())

        final_weight = focal_weight * difficulty_weight
        final_weight = final_weight / (final_weight.mean())

        loss = (final_weight * ce_loss).mean()
        return loss

# 添加 LPLLoss_advanced 类
class LPLLoss_advanced(nn.Module):
    """Adaptive logit-perturbation loss for (binary) imbalanced classification.

    The logits are pushed for a class-dependent number of steps along a
    class-level direction (see ``compute_adv_sign``). Cross entropy is then
    evaluated on the perturbed logits. Step counts and step sizes grow for the
    rarer class and for samples with larger CE gradients. The perturbation
    ``eta`` is computed without gradient, so gradients reach the model only
    through the original logits.

    NOTE: ``compute_adaptive_params`` assumes ``num_classes == 2``; the majority
    class index is taken as ``1 - minority_idx``.

    Args:
        num_classes: number of classes.
        pgd_nums: base number of perturbation steps.
        alpha: base step size.
        min_class_factor: the minority-class step size is raised to at least this
            multiple of the majority-class step size.
    """

    def __init__(self, num_classes=2, pgd_nums=50, alpha=0.1, min_class_factor=3.0):
        super().__init__()
        self.num_classes = num_classes
        self.pgd_nums = pgd_nums
        self.alpha = alpha
        self.min_class_factor = min_class_factor
        self.criterion = nn.CrossEntropyLoss()

        # Exponential moving averages (momentum 0.9) of per-class sample counts
        # and per-class mean CE loss, updated on every call.
        self.register_buffer('class_counts', torch.zeros(num_classes))
        self.register_buffer('class_grad_mags', torch.zeros(num_classes))
        self.momentum = 0.9

    def update_statistics(self, logit, y):
        """Update the EMA buffers from logits ``[N, C]`` and integer labels ``[N]``."""
        with torch.no_grad():
            batch_counts = torch.bincount(y, minlength=self.num_classes).float()
            self.class_counts = self.momentum * self.class_counts + (1 - self.momentum) * batch_counts

            # Despite the buffer name, this tracks the mean CE loss of each class.
            grad_mags = torch.zeros(self.num_classes, device=logit.device)
            for c in range(self.num_classes):
                class_mask = (y == c)
                if torch.sum(class_mask) > 0:
                    class_logits = logit[class_mask]
                    class_targets = y[class_mask]
                    ce_loss = F.cross_entropy(class_logits, class_targets, reduction='none')
                    grad_mags[c] = ce_loss.mean().item()

            self.class_grad_mags = self.momentum * self.class_grad_mags + (1 - self.momentum) * grad_mags

    def compute_adaptive_params(self, logit, y):
        """Return per-sample step counts (long ``[N]``) and step sizes (float ``[N]``).

        Side effect: updates the running statistics.
        """
        with torch.no_grad():
            self.update_statistics(logit, y)

            total_samples = torch.sum(self.class_counts)
            class_ratios = self.class_counts / (total_samples + 1e-8)

            # Binary assumption: the other class is the majority.
            minority_idx = torch.argmin(class_ratios).item()
            majority_idx = 1 - minority_idx

            imbalance_ratio = class_ratios[majority_idx] / (class_ratios[minority_idx] + 1e-8)
            imbalance_factor = torch.clamp(imbalance_ratio, 1.0, 10.0)

            grad_scale = F.softmax(self.class_grad_mags, dim=0)

            class_steps = torch.zeros(self.num_classes, device=logit.device, dtype=torch.long)
            class_alphas = torch.zeros(self.num_classes, device=logit.device, dtype=torch.float)

            max_steps = int(self.pgd_nums * 2.0)
            min_steps = max(1, int(self.pgd_nums * 0.5))

            for c in range(self.num_classes):
                # Rarer classes have a larger freq_factor and therefore more steps.
                freq_factor = torch.sqrt(1.0 / (class_ratios[c] + 1e-8))
                steps = min_steps + int((max_steps - min_steps) * freq_factor / (freq_factor + 1.0))
                class_steps[c] = steps

                alpha_base = self.alpha * (1.0 + grad_scale[c].item() * 2.0)
                if c == minority_idx:
                    alpha = alpha_base * min(5.0, imbalance_factor.item() ** 0.5)
                else:
                    alpha = alpha_base
                class_alphas[c] = alpha

            # Give the minority class at least 1.5x the steps and
            # min_class_factor x the step size of the majority class.
            if class_steps[minority_idx] < class_steps[majority_idx] * 1.5:
                class_steps[minority_idx] = int(class_steps[majority_idx] * 1.5)

            if class_alphas[minority_idx] < class_alphas[majority_idx] * self.min_class_factor:
                class_alphas[minority_idx] = class_alphas[majority_idx] * self.min_class_factor

            sample_steps = torch.zeros_like(y, dtype=torch.long)
            sample_alphas = torch.zeros_like(y, dtype=torch.float)

            for c in range(self.num_classes):
                class_mask = (y == c)
                sample_steps[class_mask] = class_steps[c]
                sample_alphas[class_mask] = class_alphas[c]

            # Per-sample difficulty: softmax over the batch of the L2 norms of
            # d CE / d logits. It scales step sizes by 0.8-1.5x and step counts
            # by 1.0-1.5x.
            with torch.enable_grad():
                logit_grad = logit.detach().clone().requires_grad_(True)
                loss = F.cross_entropy(logit_grad, y, reduction='none')

                grads = torch.autograd.grad(
                    outputs=loss.sum(),
                    inputs=logit_grad,
                    create_graph=False,
                    retain_graph=False
                )[0]

                sample_grad_norms = torch.norm(grads, p=2, dim=1)
                sample_difficulties = F.softmax(sample_grad_norms, dim=0)

                difficulty_scales = 0.8 + 0.7 * sample_difficulties / (torch.max(sample_difficulties) + 1e-8)
                sample_alphas = sample_alphas * difficulty_scales

                steps_difficulty_scales = 1.0 + 0.5 * sample_difficulties / (torch.max(sample_difficulties) + 1e-8)
                sample_steps = (sample_steps.float() * steps_difficulty_scales).long()

            return sample_steps, sample_alphas

    def compute_adv_sign(self, logit, y, sample_alphas):
        """Return one perturbation step ``[N, C]`` and the per-class gaps ``sub``.

        ``mean_class_p[c]`` is the batch-average softmax probability of class
        ``c`` over samples labelled ``c``. ``sub[c]`` is the mean of
        ``mean_class_p`` over the classes present in the batch, minus
        ``mean_class_p[c]``. A sample of class ``c`` moves along row ``c`` of the
        column-normalised ``mean_class_logit - I``, scaled by its step size and by
        ``sign(sub[c])``.
        """
        with torch.no_grad():
            logit_softmax = F.softmax(logit, dim=-1)
            y_onehot = F.one_hot(y, num_classes=self.num_classes)

            sum_class_logit = torch.matmul(y_onehot.permute(1, 0)*1.0, logit_softmax)
            sum_class_num = torch.sum(y_onehot, dim=0)
            # Record which classes occur in the batch *before* patching the counts.
            # The patch below makes every count positive, so a mask taken after it
            # would always be all-True. Absent classes (mean probability 0) would
            # then drag the threshold down.
            present = sum_class_num > 0

            # Avoid dividing by zero for classes absent from the batch.
            sum_class_num = torch.where(sum_class_num == 0, 100, sum_class_num)
            mean_class_logit = torch.div(sum_class_logit, sum_class_num.reshape(-1, 1))

            grad = mean_class_logit - torch.eye(self.num_classes, device=logit.device)
            grad = torch.div(grad, torch.norm(grad, p=2, dim=0).reshape(-1, 1) + 1e-8)

            mean_class_p = torch.diag(mean_class_logit)
            mean_class_thr = torch.mean(mean_class_p[present])
            sub = mean_class_thr - mean_class_p
            sign = sub.sign()

            alphas_expanded = sample_alphas.unsqueeze(1).expand(-1, self.num_classes)
            adv_logit = torch.index_select(grad, 0, y) * alphas_expanded * sign[y].unsqueeze(1)

            return adv_logit, sub

    def compute_eta(self, logit, y):
        """Run the perturbation and return ``(eta, sample_steps, sample_alphas)``.

        All samples are stepped together for ``max(sample_steps)`` iterations.
        Each sample then takes the logits reached after its own step count, and
        ``eta`` is that minus the input logits.
        """
        with torch.no_grad():
            sample_steps, sample_alphas = self.compute_adaptive_params(logit, y)

            max_steps = torch.max(sample_steps).item()

            # logit_steps[s] holds the logits after s perturbation steps.
            logit_steps = torch.zeros(
                [max_steps + 1, logit.shape[0], self.num_classes],
                device=logit.device, dtype=logit.dtype)

            current_logit = logit.clone()
            logit_steps[0] = current_logit

            for i in range(1, max_steps + 1):
                adv_logit, _ = self.compute_adv_sign(current_logit, y, sample_alphas)
                current_logit = current_logit + adv_logit
                logit_steps[i] = current_logit

            # Vectorised pick of logit_steps[sample_steps[n], n] for every sample n.
            rows = torch.arange(logit.shape[0], device=logit.device)
            logit_news = logit_steps[sample_steps, rows]

            eta = logit_news - logit

            return eta, sample_steps, sample_alphas

    def forward(self, models_or_logits, x=None, y=None, is_logits=False):
        """Compute cross entropy on perturbed logits.

        Args:
            models_or_logits: logits ``[N, C]`` when ``is_logits`` is True,
                otherwise a callable model that is applied to ``x``.
            x: model input (ignored when ``is_logits`` is True).
            y: integer class labels ``[N]``.
            is_logits: whether ``models_or_logits`` already holds logits.

        Returns:
            ``(loss_adv, logit, logit_news, sample_steps, sample_alphas)``.
        """
        if is_logits:
            logit = models_or_logits
        else:
            logit = models_or_logits(x)

        eta, sample_steps, sample_alphas = self.compute_eta(logit, y)

        # eta carries no gradient; gradients flow through `logit` only.
        logit_news = logit + eta

        loss_adv = self.criterion(logit_news, y)

        return loss_adv, logit, logit_news, sample_steps, sample_alphas

# The G-mean helpers `geometric_mean` and `calculate_g_mean` are defined once,
# earlier in this module. A second verbatim copy here only shadowed them.


def rgtan_main(feat_df, graph, train_idx, test_idx, labels, args, cat_features, neigh_features: pd.DataFrame, nei_att_head):
    # 设置随机种子为72
    args['seed'] = 64
    np.random.seed(args['seed'])
    torch.manual_seed(args['seed'])
    torch.cuda.manual_seed_all(args['seed'])
    
    # 设置日志
    log_dir = os.path.join(os.path.dirname(__file__), "..", "..", "logs")
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    log_file = os.path.join(log_dir, f"rgtan_log_{args.get('dataset', 'unknown')}_seed{args['seed']}.txt")
    logging.basicConfig(filename=log_file, level=logging.INFO, 
                        format='%(asctime)s - %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)
    
    device = args['device']
    logging.info(f'Device: {device}')
    graph = graph.to(device)
    oof_predictions = torch.from_numpy(
        np.zeros([len(feat_df), 2])).float().to(device)
    test_predictions = torch.from_numpy(
        np.zeros([len(feat_df), 2])).float().to(device)
    kfold = StratifiedKFold(
        n_splits=args['n_fold'], shuffle=True, random_state=args['seed'])

    y_target = labels.iloc[train_idx].values
    num_feat = torch.from_numpy(feat_df.values).float().to(device)
    cat_feat = {col: torch.from_numpy(feat_df[col].values).long().to(
        device) for col in cat_features}

    neigh_padding_dict = {}
    nei_feat = []
    if isinstance(neigh_features, pd.DataFrame):  # otherwise []
        # if null it is []
        nei_feat = {col: torch.from_numpy(neigh_features[col].values).to(torch.float32).to(
            device) for col in neigh_features.columns}
        
    y = labels
    labels = torch.from_numpy(y.values).long().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # 初始化损失函数
    gradient_aware_focal = GradientAwareFocalLoss(num_classes=2,
                                              k_percent=10,
                                              gamma_focal=1,
                                              gamma_ga=0.8,
                                              gamma_grad=1,
                                              use_softmax=True).to(device)
    
    adaptive_lpl_loss = LPLLoss_advanced(
        num_classes=2,
        pgd_nums=20,
        alpha=0.001,
        min_class_factor=3
    ).to(device)

    # 添加最佳test指标的跟踪变量
    best_test_metrics = {
        'auc': 0,
        'f1': 0,
        'ap': 0,
        'acc1': 0,
        'acc0': 0,
        'gmean': 0,
        'epoch': 0
    }

    fixed_cluster_epochs = 10
    use_clustering_pseudo_labels = True
    use_original_pseudo_labels = True

    for fold, (trn_idx, val_idx) in enumerate(kfold.split(feat_df.iloc[train_idx], y_target)):
        logging.info(f'Training fold {fold + 1}')
        
        # 原始训练索引
        original_trn_ind = np.array(train_idx)[trn_idx]
        
        # 划分正负样本
        pos_samples = [i for i in original_trn_ind if y.iloc[i] == 1]
        neg_samples = [i for i in original_trn_ind if y.iloc[i] == 0]
        
        # 如果正样本或负样本数量不足，记录警告
        if len(pos_samples) == 0:
            logging.warning("训练集中没有正样本，无法选择一个正样本")
            pos_samples = []
        if len(neg_samples) == 0:
            logging.warning("训练集中没有负样本，无法选择一个负样本")
            neg_samples = []
        
        # 选择一个正样本和一个负样本
        selected_pos = [pos_samples[0]] if len(pos_samples) > 0 else []
        selected_neg = [neg_samples[0]] if len(neg_samples) > 0 else []
        
        # 新的训练集包含一个正样本和一个负样本，其余样本标签设为2
        labeled_samples = selected_pos + selected_neg
        
        # 将剩余的有标签样本转为无标签样本(label=2)
        remaining_pos = pos_samples[1:] if len(pos_samples) > 1 else []
        remaining_neg = neg_samples[1:] if len(neg_samples) > 1 else []
        unlabeled_samples = remaining_pos + remaining_neg
        
        # 创建无标签样本的mask，而不是修改原始标签
        unlabeled_mask = torch.zeros_like(labels)
        unlabeled_mask[unlabeled_samples] = 1
        
        # 将无标签样本加入到训练集,并确保每个batch都包含labeled样本
        batch_size = args['batch_size']
        num_labeled = len(labeled_samples)
        num_unlabeled_per_batch = batch_size - num_labeled
        
        # 将unlabeled samples分成多个batch
        num_full_batches = len(unlabeled_samples) // num_unlabeled_per_batch
        
        # 重新组织训练索引列表,确保每个batch都包含labeled samples
        final_trn_ind_list = []
        for i in range(num_full_batches):
            start_idx = i * num_unlabeled_per_batch
            end_idx = start_idx + num_unlabeled_per_batch
            batch_unlabeled = unlabeled_samples[start_idx:end_idx]
            final_trn_ind_list.extend(labeled_samples + batch_unlabeled)
            
        # 处理剩余的unlabeled samples
        remaining_start = num_full_batches * num_unlabeled_per_batch
        if remaining_start < len(unlabeled_samples):
            remaining_unlabeled = unlabeled_samples[remaining_start:]
            if len(remaining_unlabeled) > 0:
                final_trn_ind_list.extend(labeled_samples + remaining_unlabeled)
        
        logging.info(f'训练集正样本数: {len(selected_pos)}, 负样本数: {len(selected_neg)}, 无标签样本数: {len(unlabeled_samples)}')
        
        trn_ind = torch.tensor(final_trn_ind_list).long().to(device)
        val_ind = torch.from_numpy(np.array(train_idx)[val_idx]).long().to(device)
        
        logging.info(f'训练/验证/测试样本数: {len(trn_ind)}, {len(val_ind)}, {len(test_idx)}')

        train_sampler = MultiLayerFullNeighborSampler(args['n_layers'])
        train_dataloader = DataLoader(graph,
                                          trn_ind,
                                          train_sampler,
                                          device=device,
                                          use_ddp=False,
                                          batch_size=args['batch_size'],
                                          shuffle=False,  # 不需要shuffle,因为我们已经组织好了数据
                                          drop_last=False,
                                          num_workers=0
                                          )
        val_sampler = MultiLayerFullNeighborSampler(args['n_layers'])
        val_dataloader = DataLoader(graph,
                                        val_ind,
                                        val_sampler,
                                        use_ddp=False,
                                        device=device,
                                        batch_size=args['batch_size'],
                                        shuffle=True,
                                        drop_last=False,
                                        num_workers=0,
                                        )
        model = RGTAN(in_feats=feat_df.shape[1],
                      hidden_dim=args['hid_dim']//4,
                      n_classes=2,
                      heads=[4]*args['n_layers'],
                      activation=nn.PReLU(),
                      n_layers=args['n_layers'],
                      drop=args['dropout'],
                      device=device,
                      gated=args['gated'],
                      ref_df=feat_df,
                      cat_features=cat_feat,
                      neigh_features=nei_feat,
                      nei_att_head=nei_att_head).to(device)
        lr = args['lr'] * np.sqrt(args['batch_size']/1024)
        optimizer = optim.Adam(model.parameters(), lr=lr,
                               weight_decay=args['wd'])
        lr_scheduler = MultiStepLR(optimizer=optimizer, milestones=[
                                   4000, 12000], gamma=0.3)

        earlystoper = early_stopper(
            patience=args['early_stopping'], verbose=True)
        start_epoch, max_epochs = 0, 2000
        for epoch in range(start_epoch, args['max_epochs']):
            train_loss_list = []
            model.train()
            for step, (input_nodes, seeds, blocks) in enumerate(train_dataloader):
                batch_inputs, batch_work_inputs, batch_neighstat_inputs, batch_labels, lpa_labels = load_lpa_subtensor(
                    num_feat,
                    cat_feat,
                    nei_feat,  # 邻居特征
                    neigh_padding_dict,  # 邻居填充字典
                    labels,
                    seeds,
                    input_nodes,
                    device,
                    blocks
                )

                blocks = [block.to(device) for block in blocks]
                
                # 生成两个特征增强视图
                _, feat_aug1 = get_augmented_view(None, batch_inputs, aug_type='feat_drop', drop_rate=0.3)
                _, feat_aug2 = get_augmented_view(None, batch_inputs, aug_type='feat_drop', drop_rate=0.2)
                
                # 获取原始视图和两个增强视图的输出
                out_orig, h_orig = model(blocks, batch_inputs, lpa_labels, batch_work_inputs, batch_neighstat_inputs)
                out_aug1, h_aug1 = model(blocks, feat_aug1, lpa_labels, batch_work_inputs, batch_neighstat_inputs)
                out_aug2, h_aug2 = model(blocks, feat_aug2, lpa_labels, batch_work_inputs, batch_neighstat_inputs)
                
                # 在训练时使用mask来识别无标签样本
                mask = unlabeled_mask[seeds] == 1
                train_batch_logits = out_orig[~mask]
                batch_labels = labels[seeds][~mask]  # 使用原始标签

                # 分类损失
                classification_loss = F.nll_loss(train_batch_logits, batch_labels)
                
                # 对比学习损失
                batch_labeled = seeds[~mask]  # 获取有标签数据的索引
                positive_pairs, negative_pairs = generate_contrastive_pairs(batch_labeled, labels)
                
                if len(positive_pairs) > 0:
                    z_i_1 = h_aug1[torch.tensor([p[0] for p in positive_pairs], dtype=torch.long, device=device)]
                    z_j_1 = h_aug1[torch.tensor([p[1] for p in positive_pairs], dtype=torch.long, device=device)]
                    contrastive_loss_1 = nt_xent_loss(z_i_1, z_j_1)
                        
                    # 在第二个增强视图内计算对比损失
                    z_i_2 = h_aug2[torch.tensor([p[0] for p in positive_pairs], dtype=torch.long, device=device)]
                    z_j_2 = h_aug2[torch.tensor([p[1] for p in positive_pairs], dtype=torch.long, device=device)]
                    contrastive_loss_2 = nt_xent_loss(z_i_2, z_j_2)
                    contrastive_loss_2 = nt_xent_loss(z_i_2, z_j_2)
                        
                    # 取两个视图的平均对比损失
                    contrastive_loss = (contrastive_loss_1 + contrastive_loss_2) / 2
                else:
                    contrastive_loss = torch.tensor(0.0).to(device)
                
                # consistency loss
                consistency_loss = F.mse_loss(h_aug1, h_aug2)
                
                # 聚类和伪标签
                h_orig_unlabeled = h_orig[mask]
                h1_unlabeled = h_aug1[mask]
                h2_unlabeled = h_aug2[mask]
                
                labeled_features_orig = h_orig[~mask]
                labeled_classes = batch_labeled

                # 在前10个epoch中固定聚类中心为已有的正常和欺诈样本
                if epoch < fixed_cluster_epochs and use_clustering_pseudo_labels:
                    # 使用有标签数据初始化原始图的聚类中心，同时处理三个视图
                    cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2, centroids_orig = robust_node_clustering(
                        h_orig_unlabeled,  # 原始图特征
                        k=2, 
                        temperature=0.8,
                        max_iterations=10,
                        labeled_features=labeled_features_orig,  # 传入有标签样本特征
                        labeled_classes=labeled_classes       # 传入有标签样本类别
                    )
                elif use_clustering_pseudo_labels:
                    # 10个epoch后，让聚类算法自由寻找更好的聚类中心，同时处理三个视图
                    cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2, centroids_orig = robust_node_clustering(
                        h_orig_unlabeled,  # 原始图特征
                        k=2, 
                        temperature=0.8,
                        max_iterations=10
                    )

                # 创建合并的特征和分配
                all_features = torch.cat([h_orig_unlabeled, h1_unlabeled, h2_unlabeled], dim=0)
                all_assignments = torch.cat([cluster_assignments_orig, cluster_assignments_view1, cluster_assignments_view2], dim=0)
                # 计算统一的聚类损失
                clustering_loss, num_pos_all, num_neg_all = compute_clustering_loss(
                    all_features, 
                    all_assignments, 
                    centroids_orig  # 使用原始图的聚类中心
                )

                with torch.no_grad():
                    final_pseudo_labels_for_batch_unlabeled = torch.tensor([], dtype=torch.long, device=device)
                    
                    # 确定伪标签来源和计算逻辑
                    if use_original_pseudo_labels and use_clustering_pseudo_labels:
                        # 场景1: 模型输出 + 聚类结果 融合
                        orig_logits_unlabeled = out_orig[mask]
                        orig_probs_unlabeled = F.softmax(orig_logits_unlabeled, dim=1)

                        # 确定聚类0和聚类1哪个是多数 (负)，哪个是少数 (正)
                        temp_cluster_hard_labels = torch.argmax(cluster_assignments_orig, dim=1)
                        count_c0 = torch.sum(temp_cluster_hard_labels == 0).item()
                        count_c1 = torch.sum(temp_cluster_hard_labels == 1).item()
                        
                        aligned_cluster_probs = cluster_assignments_orig.clone()
                        if count_c0 < count_c1:
                            aligned_cluster_probs[:, 0] = cluster_assignments_orig[:, 1]
                            aligned_cluster_probs[:, 1] = cluster_assignments_orig[:, 0]
                        
                        combined_probs_unlabeled = (orig_probs_unlabeled + aligned_cluster_probs) / 2.0
                        final_pseudo_labels_for_batch_unlabeled = torch.argmax(combined_probs_unlabeled, dim=1)

                    elif use_clustering_pseudo_labels:
                        # 场景2: 仅使用聚类结果
                        temp_cluster_hard_labels = torch.argmax(cluster_assignments_orig, dim=1)
                        count_c0 = torch.sum(temp_cluster_hard_labels == 0).item()
                        count_c1 = torch.sum(temp_cluster_hard_labels == 1).item()

                        if count_c0 >= count_c1:
                            final_pseudo_labels_for_batch_unlabeled = temp_cluster_hard_labels
                        else:
                            final_pseudo_labels_for_batch_unlabeled = 1 - temp_cluster_hard_labels
                    
                    elif use_original_pseudo_labels:
                        # 场景3: 仅使用模型输出
                        consistent_high_conf_indices = torch.arange(final_pseudo_labels_for_batch_unlabeled.size(0), device=device)
                        consistent_pseudo_labels = final_pseudo_labels_for_batch_unlabeled

                # 使用伪标签计算损失
                if final_pseudo_labels_for_batch_unlabeled.numel() > 0:
                    consistent_high_conf_indices = torch.arange(final_pseudo_labels_for_batch_unlabeled.size(0), device=device)
                    consistent_pseudo_labels = final_pseudo_labels_for_batch_unlabeled
                    pseudo_logits_1 = out_aug1[mask][consistent_high_conf_indices]
                    pseudo_logits_2 = out_aug2[mask][consistent_high_conf_indices]

                    # 使用GradientAwareFocalLoss
                    pseudo_label_loss_1 = gradient_aware_focal(
                        pseudo_logits_1, 
                       consistent_pseudo_labels
                    )
                    pseudo_label_loss_2 = gradient_aware_focal(
                        pseudo_logits_2, 
                        consistent_pseudo_labels
                    )
                    pseudo_label_loss = (pseudo_label_loss_1 + pseudo_label_loss_2) / 2

                    # 使用LPLLoss
                    adap_lpl_loss_1, _, _, steps_1, alphas_1 = adaptive_lpl_loss(pseudo_logits_1, None, consistent_pseudo_labels, is_logits=True)
                    adap_lpl_loss_2, _, _, steps_2, alphas_2 = adaptive_lpl_loss(pseudo_logits_2, None, consistent_pseudo_labels, is_logits=True)
                    pseudo_lpl_loss = (adap_lpl_loss_1 + adap_lpl_loss_2) / 2
                else:
                    pseudo_label_loss = torch.tensor(0.0).to(device)
                    pseudo_lpl_loss = torch.tensor(0.0).to(device)

                # 获取当前mu值
                current_mu = get_current_mu(epoch)

                # 总损失
                train_loss = classification_loss + contrastive_loss + \
                            current_mu * consistency_loss + current_mu * clustering_loss + \
                            current_mu * pseudo_label_loss + current_mu * pseudo_lpl_loss

                # backward
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()
                lr_scheduler.step()
                train_loss_list.append(train_loss.cpu().detach().numpy())

                if step % 10 == 0:
                    tr_batch_pred = torch.sum(torch.argmax(train_batch_logits.clone(
                    ).detach(), dim=1) == batch_labels) / batch_labels.shape[0]
                    score = torch.softmax(train_batch_logits.clone().detach(), dim=1)[
                        :, 1].cpu().numpy()
                        
                    # 计算正负样本准确率和G-mean
                    pred_labels = torch.argmax(train_batch_logits.clone().detach(), dim=1).cpu().numpy()
                    batch_labels_np = batch_labels.cpu().numpy()
                    
                    pos_indices = (batch_labels_np == 1)
                    neg_indices = (batch_labels_np == 0)
                    
                    train_acc1 = np.mean(pred_labels[pos_indices] == batch_labels_np[pos_indices]) if np.any(pos_indices) else 0.0
                    train_acc0 = np.mean(pred_labels[neg_indices] == batch_labels_np[neg_indices]) if np.any(neg_indices) else 0.0
                    
                    # 计算G-mean
                    train_gmean = calculate_g_mean(batch_labels_np, pred_labels)
                    
                    try:
                        log_msg = ('In epoch:{:03d}|batch:{:04d}, train_loss:{:4f}, '
                                  'train_ap:{:.4f}, train_acc:{:.4f}, train_auc:{:.4f}, '
                                  'train_acc1:{:.4f}, train_acc0:{:.4f}, train_gmean:{:.4f}')
                        
                        logging.info(log_msg.format(epoch, step,
                                                     np.mean(train_loss_list),
                                                     average_precision_score(batch_labels.cpu().numpy(), score),
                                                     tr_batch_pred.detach(),
                                                     roc_auc_score(batch_labels.cpu().numpy(), score),
                                                     train_acc1, train_acc0, train_gmean))
                    except Exception as e:
                        logging.error(f"Error calculating metrics: {e}")

            # mini-batch for validation
            val_loss_list = 0
            val_acc_list = 0
            val_all_list = 0
            val_batch_all_preds = []
            val_batch_all_labels = []
            model.eval()
            with torch.no_grad():
                for step, (input_nodes, seeds, blocks) in enumerate(val_dataloader):
                    batch_inputs, batch_work_inputs, batch_neighstat_inputs, batch_labels, lpa_labels = load_lpa_subtensor(num_feat, cat_feat, nei_feat, neigh_padding_dict, labels,
                                                                                                                       seeds, input_nodes, device, blocks)

                    blocks = [block.to(device) for block in blocks]
                    val_batch_logits, _ = model(
                        blocks, batch_inputs, lpa_labels, batch_work_inputs, batch_neighstat_inputs)
                    oof_predictions[seeds] = torch.exp(val_batch_logits)  # 转换回概率
                    mask = batch_labels == 2
                    val_batch_logits = val_batch_logits[~mask]
                    batch_labels = batch_labels[~mask]
                    val_loss_list = val_loss_list + \
                        F.nll_loss(val_batch_logits, batch_labels)  # 使用nll_loss
                    val_batch_pred = torch.sum(torch.argmax(
                        val_batch_logits, dim=1) == batch_labels) / torch.tensor(batch_labels.shape[0])
                    val_acc_list = val_acc_list + val_batch_pred * \
                        torch.tensor(batch_labels.shape[0])
                    val_all_list = val_all_list + batch_labels.shape[0]
                    
                    # 收集预测和标签用于计算整体指标
                    pred_labels = torch.argmax(val_batch_logits, dim=1).cpu().numpy()
                    batch_labels_np = batch_labels.cpu().numpy()
                    val_batch_all_preds.append(pred_labels)
                    val_batch_all_labels.append(batch_labels_np)
                    
                    if step % 10 == 0:
                        score = torch.exp(val_batch_logits)[:, 1].cpu().numpy()  # 转换回概率
                        
                        # 计算正负样本准确率
                        pos_indices = (batch_labels_np == 1)
                        neg_indices = (batch_labels_np == 0)
                        
                        val_acc1 = np.mean(pred_labels[pos_indices] == batch_labels_np[pos_indices]) if np.any(pos_indices) else 0.0
                        val_acc0 = np.mean(pred_labels[neg_indices] == batch_labels_np[neg_indices]) if np.any(neg_indices) else 0.0
                        
                        # 计算G-mean
                        val_gmean = calculate_g_mean(batch_labels_np, pred_labels)
                        
                        try:
                            log_msg = ('In epoch:{:03d}|batch:{:04d}, val_loss:{:4f}, val_ap:{:.4f}, '
                                      'val_acc:{:.4f}, val_auc:{:.4f}, val_acc1:{:.4f}, val_acc0:{:.4f}, val_gmean:{:.4f}')
                            
                            logging.info(log_msg.format(epoch,
                                                          step,
                                                          val_loss_list/val_all_list,
                                                          average_precision_score(batch_labels_np, score),
                                                          val_batch_pred.detach(),
                                                          roc_auc_score(batch_labels_np, score),
                                                          val_acc1, val_acc0, val_gmean))
                        except Exception as e:
                            logging.error(f"Error calculating validation metrics: {e}")
                
                # 计算整体验证集指标
                if len(val_batch_all_labels) > 0 and len(val_batch_all_preds) > 0:
                    all_val_labels = np.concatenate(val_batch_all_labels)
                    all_val_preds = np.concatenate(val_batch_all_preds)
                    
                    pos_indices = (all_val_labels == 1)
                    neg_indices = (all_val_labels == 0)
                    
                    val_acc1 = np.mean(all_val_preds[pos_indices] == all_val_labels[pos_indices]) if np.any(pos_indices) else 0.0
                    val_acc0 = np.mean(all_val_preds[neg_indices] == all_val_labels[neg_indices]) if np.any(neg_indices) else 0.0
                    val_gmean = calculate_g_mean(all_val_labels, all_val_preds)
                    
                    logging.info(f'Epoch {epoch} validation metrics - ACC1: {val_acc1:.4f}, ACC0: {val_acc0:.4f}, G-mean: {val_gmean:.4f}')

            # val_acc_list/val_all_list, model)
            earlystoper.earlystop(val_loss_list/val_all_list, model)
            if earlystoper.is_earlystop:
                logging.info("Early Stopping!")
                break
        logging.info("Best val_loss is: {:.7f}".format(earlystoper.best_cv))
        test_ind = torch.from_numpy(np.array(test_idx)).long().to(device)
        test_sampler = MultiLayerFullNeighborSampler(args['n_layers'])
        test_dataloader = DataLoader(graph,
                                         test_ind,
                                         test_sampler,
                                         use_ddp=False,
                                         device=device,
                                         batch_size=args['batch_size'],
                                         shuffle=True,
                                         drop_last=False,
                                         num_workers=0,
                                         )
        b_model = earlystoper.best_model.to(device)
        b_model.eval()
        test_batch_all_preds = []
        test_batch_all_labels = []
        test_batch_all_scores = []
        
        with torch.no_grad():
            for step, (input_nodes, seeds, blocks) in enumerate(test_dataloader):
                batch_inputs, batch_work_inputs, batch_neighstat_inputs, batch_labels, lpa_labels = load_lpa_subtensor(
                    num_feat,
                    cat_feat,
                    nei_feat,  # 邻居特征
                    neigh_padding_dict,  # 邻居填充字典
                    labels,
                    seeds,
                    input_nodes,
                    device,
                    blocks
                )

                blocks = [block.to(device) for block in blocks]
                test_batch_logits, _ = b_model(blocks, batch_inputs, lpa_labels, batch_work_inputs, batch_neighstat_inputs)

                test_predictions[seeds] = test_batch_logits
                
                # 收集预测和标签用于计算整体指标
                pred_labels = torch.argmax(test_batch_logits, dim=1).cpu().numpy()
                batch_labels_np = batch_labels.cpu().numpy()
                test_scores = torch.softmax(test_batch_logits, dim=1)[:, 1].cpu().numpy()
                
                test_batch_all_preds.append(pred_labels)
                test_batch_all_labels.append(batch_labels_np)
                test_batch_all_scores.append(test_scores)
                
                if step % 10 == 0:
                    logging.info('In test batch:{:04d}'.format(step))
                    
        # 计算当前epoch的test指标
        if len(test_batch_all_labels) > 0 and len(test_batch_all_preds) > 0:
            all_test_labels = np.concatenate(test_batch_all_labels)
            all_test_preds = np.concatenate(test_batch_all_preds)
            all_test_scores = np.concatenate(test_batch_all_scores)
            
            mask = all_test_labels != 2
            all_test_labels = all_test_labels[mask]
            all_test_preds = all_test_preds[mask]
            all_test_scores = all_test_scores[mask]
            
            pos_indices = (all_test_labels == 1)
            neg_indices = (all_test_labels == 0)
            
            current_test_metrics = {}
            current_test_metrics['auc'] = roc_auc_score(all_test_labels, all_test_scores)
            current_test_metrics['f1'] = f1_score(all_test_labels, all_test_preds, average="macro")
            current_test_metrics['ap'] = average_precision_score(all_test_labels, all_test_scores)
            current_test_metrics['acc1'] = np.mean(all_test_preds[pos_indices] == all_test_labels[pos_indices]) if np.any(pos_indices) else 0.0
            current_test_metrics['acc0'] = np.mean(all_test_preds[neg_indices] == all_test_labels[neg_indices]) if np.any(neg_indices) else 0.0
            current_test_metrics['gmean'] = calculate_g_mean(all_test_labels, all_test_preds)
            
            # 输出当前epoch的test指标
            logging.info(f'Epoch {epoch} test metrics:')
            logging.info(f'AUC: {current_test_metrics["auc"]:.4f}')
            logging.info(f'F1: {current_test_metrics["f1"]:.4f}')
            logging.info(f'AP: {current_test_metrics["ap"]:.4f}')
            logging.info(f'ACC1: {current_test_metrics["acc1"]:.4f}')
            logging.info(f'ACC0: {current_test_metrics["acc0"]:.4f}')
            logging.info(f'G-mean: {current_test_metrics["gmean"]:.4f}')
            
            # 更新最佳指标
            if current_test_metrics['gmean'] > best_test_metrics['gmean']:
                best_test_metrics = current_test_metrics
                best_test_metrics['epoch'] = epoch

    # 在最后输出最佳test指标
    logging.info("\nBest test metrics (at epoch {}):".format(best_test_metrics['epoch']))
    logging.info("Best test AUC: {:.4f}".format(best_test_metrics['auc']))
    logging.info("Best test F1: {:.4f}".format(best_test_metrics['f1']))
    logging.info("Best test AP: {:.4f}".format(best_test_metrics['ap']))
    logging.info("Best test ACC1: {:.4f}".format(best_test_metrics['acc1']))
    logging.info("Best test ACC0: {:.4f}".format(best_test_metrics['acc0']))
    logging.info("Best test G-mean: {:.4f}".format(best_test_metrics['gmean']))
    
    logging.info("\nFinal test metrics:")
    mask = y_target == 2
    y_target[mask] = 0
    my_ap = average_precision_score(y_target, torch.softmax(
        oof_predictions, dim=1).cpu()[train_idx, 1])
    logging.info("NN out of fold AP is: {:.4f}".format(my_ap))
    b_models, val_gnn_0, test_gnn_0 = earlystoper.best_model.to(
        'cpu'), oof_predictions, test_predictions

    test_score = torch.softmax(test_gnn_0, dim=1)[test_idx, 1].cpu().numpy()
    y_target = labels[test_idx].cpu().numpy()
    test_score1 = torch.argmax(test_gnn_0, dim=1)[test_idx].cpu().numpy()

    mask = y_target != 2
    test_score = test_score[mask]
    y_target = y_target[mask]
    test_score1 = test_score1[mask]

    # 计算最终测试指标
    test_auc = roc_auc_score(y_target, test_score)
    test_f1 = f1_score(y_target, test_score1, average="macro")
    test_ap = average_precision_score(y_target, test_score)
    
    # 计算正负样本准确率
    pos_indices = (y_target == 1)
    neg_indices = (y_target == 0)
    
    test_acc1 = np.mean(test_score1[pos_indices] == y_target[pos_indices]) if np.any(pos_indices) else 0.0
    test_acc0 = np.mean(test_score1[neg_indices] == y_target[neg_indices]) if np.any(neg_indices) else 0.0
    
    # 计算G-mean
    test_gmean = calculate_g_mean(y_target, test_score1)
    
    logging.info("Final test AUC: {:.4f}".format(test_auc))
    logging.info("Final test F1: {:.4f}".format(test_f1))
    logging.info("Final test AP: {:.4f}".format(test_ap))
    logging.info("Final test ACC1: {:.4f}".format(test_acc1))
    logging.info("Final test ACC0: {:.4f}".format(test_acc0))
    logging.info("Final test G-mean: {:.4f}".format(test_gmean))


def _build_sffsd_graph(data, edge_per_trans=3):
    """Build the S-FFSD transaction graph from a label-filtered DataFrame.

    For each of the columns Source/Target/Location/Type, rows sharing the same
    value are sorted by ``Time``. Each row is linked to itself and to the next
    ``edge_per_trans - 1`` rows of its group; the ``j == 0`` term is the
    self-loop. Node ids are the index labels of ``data``, so ``data`` must
    carry a 0..len(data)-1 RangeIndex (the caller resets the index).
    """
    alls, allt = [], []
    for column in ["Source", "Target", "Location", "Type"]:
        for _, c_df in tqdm(data.groupby(column), desc=column):
            sorted_idxs = c_df.sort_values(by="Time").index
            df_len = len(sorted_idxs)
            alls.extend(sorted_idxs[i] for i in range(df_len)
                        for j in range(edge_per_trans) if i + j < df_len)
            allt.extend(sorted_idxs[i + j] for i in range(df_len)
                        for j in range(edge_per_trans) if i + j < df_len)
    # Pin the node count so rows that end up without edges still get a node,
    # keeping the per-node ndata tensors assigned later the same size as the graph.
    return dgl.graph((np.array(alls), np.array(allt)), num_nodes=len(data))


def _load_homo_graph(adjlist_path, num_nodes):
    """Build a homogeneous DGL graph from a pickled adjacency-list mapping.

    Every ``(i, j)`` with ``j in homo[i]`` becomes one edge ``i -> j``. The
    pickle presumably holds a dict of neighbour sets. TODO: confirm this
    against the preprocessing script.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted,
    # locally preprocessed files here.
    with open(adjlist_path, 'rb') as file:
        homo = pickle.load(file)
    src, tgt = [], []
    for i in homo:
        for j in homo[i]:
            src.append(i)
            tgt.append(j)
    # Explicit node count: inferring it from the max edge id would drop
    # trailing nodes that have no edges and break the ndata assignment.
    return dgl.graph((np.array(src), np.array(tgt)), num_nodes=num_nodes)


def _attach_and_save_graph(g, labels, feat_data, graph_path):
    """Store ``labels``/``feat_data`` as ``label``/``feat`` ndata on ``g`` and save it."""
    g.ndata['label'] = torch.from_numpy(labels.to_numpy()).to(torch.long)
    g.ndata['feat'] = torch.from_numpy(
        feat_data.to_numpy()).to(torch.float32)
    dgl.data.utils.save_graphs(graph_path, [g])


def _try_load_neigh_features(csv_path):
    """Best-effort load of a precomputed neighbourhood-feature CSV.

    Returns the DataFrame read from ``csv_path``. Returns an empty list when
    the file is missing or cannot be parsed.
    """
    try:
        feat_neigh = pd.read_csv(csv_path)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. pandas ParserError and
        # EmptyDataError both subclass ValueError.
        print("no neighbohood feature used.")
        return []
    print("neighborhood feature loaded for nn input.")
    return feat_neigh


def loda_rgtan_data(dataset: str, test_size: float, prefix: str = "data/"):
    """Load a fraud-detection dataset, build its DGL graph and split it.

    Side effect: the constructed graph is written to
    ``<prefix>/graph-<dataset>.bin``.

    Args:
        dataset: one of ``'S-FFSD'``, ``'yelp'`` or ``'amazon'``.
        test_size: fraction of the split nodes placed in the test set. Only
            yelp and amazon use it; S-FFSD always uses 0.6 (see note below).
        prefix: directory holding the raw data files. Defaults to ``"data/"``,
            the previously hard-coded location.

    Returns:
        tuple: ``(feat_data, labels, train_idx, test_idx, g, cat_features,
        neigh_features)``:
        - ``feat_data``: DataFrame of node features.
        - ``labels``: Series of node labels.
        - ``train_idx`` / ``test_idx``: lists of node ids.
        - ``g``: the DGL graph.
        - ``cat_features``: names of categorical feature columns.
        - ``neigh_features``: DataFrame of precomputed neighbourhood features,
          or ``[]`` when none are available (yelp/amazon only).

    Raises:
        ValueError: if ``dataset`` is not supported.
        OSError: if a required data file (including the S-FFSD neighbourhood
            feature CSV) cannot be read.
    """
    graph_path = os.path.join(prefix, "graph-{}.bin".format(dataset))

    if dataset == 'S-FFSD':
        cat_features = ["Target", "Location", "Type"]

        df = pd.read_csv(os.path.join(prefix, "S-FFSDneofull.csv"))
        df = df.loc[:, ~df.columns.str.contains('Unnamed')]
        # Keep labels 0/1/2. Label 2 appears to mark unlabeled transactions,
        # which the training code masks out.
        data = df[df["Labels"] <= 2].reset_index(drop=True)

        # The graph is built from the raw column values, before label encoding.
        g = _build_sffsd_graph(data)

        for col in ["Source", "Target", "Location", "Type"]:
            le = LabelEncoder()
            data[col] = le.fit_transform(data[col].apply(str).values)
        feat_data = data.drop("Labels", axis=1)
        labels = data["Labels"]

        _attach_and_save_graph(g, labels, feat_data, graph_path)

        index = list(range(len(labels)))
        # NOTE(review): S-FFSD ignores `test_size` and always holds out 60%.
        # This is kept as-is to preserve existing experiment splits; confirm
        # whether it is intended.
        train_idx, test_idx, _, _ = train_test_split(index, labels, stratify=labels, test_size=0.6,
                                                     random_state=72, shuffle=True)
        # Neighbourhood features are mandatory for S-FFSD: a missing file raises.
        neigh_features = pd.read_csv(
            os.path.join(prefix, "S-FFSD_neigh_feat.csv"))
        print("neighborhood feature loaded for nn input.")

    elif dataset in ('yelp', 'amazon'):
        # (mat file, adjacency-list pickle, neighbour-feature csv, first node id in split)
        # Amazon excludes node ids below 3305 from the split; presumably those
        # nodes are unlabeled in the original dataset. TODO: confirm.
        mat_name, adj_name, neigh_name, split_start = {
            'yelp': ('YelpChi.mat', 'yelp_homo_adjlists.pickle', 'yelp_neigh_feat.csv', 0),
            'amazon': ('Amazon.mat', 'amz_homo_adjlists.pickle', 'amazon_neigh_feat.csv', 3305),
        }[dataset]

        cat_features = []
        data_file = loadmat(os.path.join(prefix, mat_name))
        labels = pd.DataFrame(data_file['label'].flatten())[0]
        feat_data = pd.DataFrame(data_file['features'].toarray())

        index = list(range(split_start, len(labels)))
        split_labels = labels[split_start:]
        train_idx, test_idx, _, _ = train_test_split(index, split_labels, stratify=split_labels,
                                                     test_size=test_size, random_state=72, shuffle=True)

        g = _load_homo_graph(os.path.join(prefix, adj_name), len(labels))
        _attach_and_save_graph(g, labels, feat_data, graph_path)

        neigh_features = _try_load_neigh_features(os.path.join(prefix, neigh_name))

    else:
        raise ValueError(
            "Unsupported dataset {!r}; expected one of 'S-FFSD', 'yelp', 'amazon'.".format(dataset))

    return feat_data, labels, train_idx, test_idx, g, cat_features, neigh_features
